{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tuning ViT Classification models using AdapterPlus\n",
    "\n",
    "In this vision tutorial, we will show how to fine-tune ViT Image models using the `AdapterPlus` Config, which is a bottleneck adapter using the parameters as defined in the `AdapterPlus` paper. For more information on bottleneck adapters, you can visit our basic [tutorial](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb) and our docs [page](https://docs.adapterhub.ml/methods#bottleneck-adapters).\n",
    "\n",
    "You can find the link to the `AdapterPlus` paper by Steitz and Roth [here](https://openaccess.thecvf.com/content/CVPR2024/papers/Steitz_Adapters_Strike_Back_CVPR_2024_paper.pdf) and their GitHub page [here](https://github.com/visinf/adapter_plus)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installation\n",
    "\n",
    "Before we can get started, we need to ensure the proper packages are installed. Here's a breakdown of what we need:\n",
    "\n",
    "- `adapters` to load the model and the adapter configuration\n",
    "- `accelerate` for training optimization\n",
    "- `evaluate` for metric computation and model evaluation\n",
    "- `datasets` to import the datasets for training and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -qq -U adapters datasets accelerate torchvision evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset\n",
    "\n",
    "For this tutorial, we will be using a light-weight image dataset `cifar100`, which contains 60,000 images with 100 classes. We use the `datasets` library to directly import the dataset and split it into its training and evaluation datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "num_classes = 100\n",
    "train_dataset = load_dataset(\"uoft-cs/cifar100\", split = \"train\")\n",
    "eval_dataset = load_dataset(\"uoft-cs/cifar100\", split = \"test\")\n",
    "\n",
    "train_dataset.set_format(\"torch\")\n",
    "eval_dataset.set_format(\"torch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this tutorial, we will be using the fine_label, to match the number of classes (100) that were used in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset[0].keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will now initialize our `ViT` image processor to convert the images into a more friendly format. It will also apply transformations to each image in order to improve the performance during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = 'google/vit-base-patch16-224-in21k'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import ViTImageProcessor\n",
    "processor = ViTImageProcessor.from_pretrained(model_name_or_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll print out the processor here in order to get an idea of what types of transformations it is applying onto the iamge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "processor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Preprocessing\n",
    "\n",
    "We will pre-process every image as defined in the `processor` above, and add the `label` key which contains our labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_image(example):\n",
    "  image = processor(example[\"img\"], return_tensors='pt')\n",
    "  image[\"label\"] = example[\"fine_label\"]\n",
    "  return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = train_dataset.map(preprocess_image)\n",
    "eval_dataset = eval_dataset.map(preprocess_image)\n",
    "#remove uneccessary columns\n",
    "train_dataset = train_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])\n",
    "eval_dataset = eval_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Defining a Datacollator\n",
    "\n",
    "We'll be using a very simple custom datacollator to help us combine multiple data samples into one batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class DataCollator:\n",
    "  processor : Any\n",
    "  def __call__(self, inputs):\n",
    "\n",
    "    pixel_values = [input[\"pixel_values\"].squeeze() for input in inputs]\n",
    "    labels = [input[\"label\"] for input in inputs]\n",
    "\n",
    "    pixel_values = torch.stack(pixel_values)\n",
    "    labels = torch.stack(labels)\n",
    "    return {\n",
    "        'pixel_values': pixel_values,\n",
    "        'labels': labels,\n",
    "    }\n",
    "\n",
    "data_collator = DataCollator(processor = processor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the `ViT` model and the `AdapterPlusConfig`\n",
    "\n",
    "Here we load the `vit-base-patch16-224-in21k` model similar to the one used in the `AdapterConfig` paper. We will load the model using the `adapters` `AutoAdapterModel` and add the corresponding `AdapterPlusConfig`. To read more about the config, you can check out the docs page [here](https://docs.adapterhub.ml/methods#bottleneck-adapters) under `AdapterPlusConfig`.\n",
    "\n",
    "#### Important Note\n",
    "\n",
    "Please note that some configurations of the adapters parameters `original_ln_after`, `original_ln_before`, and \n",
    "`residual_before_ln` may result in performance issues when training. \n",
    "\n",
    "In the general case:\n",
    "\n",
    "1) At least one of `original_ln_before` or `original_ln_after` should be set to `True` in order to ensure that the original residual\n",
    "    connection from pre-training is preserved. \n",
    "2) If `original_ln_after` is set to `False`, `residual_before_ln` must also be set to `False` to ensure convergence during training."
   ]
  },
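  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of these two rules (assuming the `BnConfig`-style keyword arguments that `AdapterPlusConfig` inherits; check your installed `adapters` version if any are missing), a valid custom configuration could look like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from adapters import AdapterPlusConfig\n",
    "\n",
    "# Sketch: a custom bottleneck configuration that satisfies both rules above.\n",
    "custom_config = AdapterPlusConfig(\n",
    "    original_ln_before=True,   # rule 1: keep the residual connection from pre-training\n",
    "    original_ln_after=False,\n",
    "    residual_before_ln=False,  # rule 2: required once original_ln_after=False\n",
    ")"
   ]
  },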
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from adapters import ViTAdapterModel\n",
    "from adapters import AdapterPlusConfig\n",
    "\n",
    "model = ViTAdapterModel.from_pretrained(model_name_or_path)\n",
    "config = AdapterPlusConfig()\n",
    "\n",
    "model.add_adapter(\"adapterplus_config\", config)\n",
    "model.add_image_classification_head(\"adapterplus_config\", num_labels=num_classes)\n",
    "model.train_adapter(\"adapterplus_config\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Name                     Architecture         #Param      %Param  Active   Train\n",
      "--------------------------------------------------------------------------------\n",
      "adapterplus_config       bottleneck          165,984       0.192       1       1\n",
      "--------------------------------------------------------------------------------\n",
      "Full model                                86,389,248     100.000               0\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "print(model.adapter_summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation Metrics\n",
    "\n",
    "We'll use accuracy as our main metric to evaluate the perforce of the reft model on the `cifar100` dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import evaluate\n",
    "accuracy = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(p):\n",
    "  return accuracy.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training\n",
    "\n",
    "Now we are ready to train our model. The same set of hyper-parameters that were used in the original paper will be re-used, except for the number of training epochs that the model will be trained on. You can always adjust the number of epochs yourself or any other hyperparameter in the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from adapters import AdapterTrainer\n",
    "from transformers import TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./training_results',\n",
    "    eval_strategy='epoch',\n",
    "    learning_rate=10e-3,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=64,\n",
    "    num_train_epochs=5,\n",
    "    weight_decay=10e-4,\n",
    "    report_to = \"none\",\n",
    "    remove_unused_columns=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='7820' max='7820' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [7820/7820 2:44:24, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.523500</td>\n",
       "      <td>0.351386</td>\n",
       "      <td>0.897700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.264200</td>\n",
       "      <td>0.329842</td>\n",
       "      <td>0.906500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.184700</td>\n",
       "      <td>0.324673</td>\n",
       "      <td>0.911400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.135400</td>\n",
       "      <td>0.347025</td>\n",
       "      <td>0.909100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.114700</td>\n",
       "      <td>0.349541</td>\n",
       "      <td>0.910900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.074300</td>\n",
       "      <td>0.370867</td>\n",
       "      <td>0.909600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.058100</td>\n",
       "      <td>0.373732</td>\n",
       "      <td>0.912400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.039300</td>\n",
       "      <td>0.376220</td>\n",
       "      <td>0.913900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.028200</td>\n",
       "      <td>0.381238</td>\n",
       "      <td>0.913300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.021600</td>\n",
       "      <td>0.380825</td>\n",
       "      <td>0.912400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=7820, training_loss=0.13645367293101748, metrics={'train_runtime': 9867.7777, 'train_samples_per_second': 50.67, 'train_steps_per_second': 0.792, 'total_flos': 3.9121684697088e+19, 'train_loss': 0.13645367293101748, 'epoch': 10.0})"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer = AdapterTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    tokenizer=processor,\n",
    "    compute_metrics = compute_metrics\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
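  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you'd like a final evaluation pass over the full test split after training, here is a minimal sketch using the standard `Trainer` API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# runs a full evaluation pass and reports the metrics from compute_metrics\n",
    "trainer.evaluate()"
   ]
  },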
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference\n",
    "\n",
    "Now, we'll use our `adapters` trained model to classify some new images!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import Softmax\n",
    "#select a random sample from the evaluation dataset\n",
    "image = eval_dataset.select([0])\n",
    "logits = model(image['pixel_values'].squeeze(0))\n",
    "softmax = Softmax(0)\n",
    "prediction = torch.argmax(softmax(logits.logits.squeeze()))\n",
    "\n",
    "prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our prediction is the 49th class which corresponds to the 'mountain' label"
   ]
  },
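  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, we can map the predicted index back to its class name via the dataset's `ClassLabel` feature. This is a minimal sketch: we reload the split's metadata because the original `fine_label` column was removed during preprocessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# reload the dataset features to recover the ClassLabel index-to-name mapping\n",
    "features = load_dataset(\"uoft-cs/cifar100\", split=\"test\").features\n",
    "features[\"fine_label\"].int2str(prediction.item())  # expected: 'mountain'"
   ]
  },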
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving the adapter model\n",
    "\n",
    "If you would like to save your model or push it to HuggingFace, you can always do so with the below code. Make sure to sign in before you do"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.push_adapter_to_hub(\n",
    "    \"cifar100-adapterplus_config\",\n",
    "    \"adapterplus_config\",\n",
    "    datasets_tag=\"cifar100\"\n",
    ")"
   ]
  }
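,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, you can save the trained adapter to disk and reload it in a later session. This is a minimal sketch using the standard `adapters` calls; the local path is just an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the adapter weights locally (the path is just an example)\n",
    "model.save_adapter(\"./cifar100_adapterplus\", \"adapterplus_config\")\n",
    "\n",
    "# later, in a fresh session, you could reload it like this:\n",
    "# model = ViTAdapterModel.from_pretrained(model_name_or_path)\n",
    "# adapter_name = model.load_adapter(\"./cifar100_adapterplus\")\n",
    "# model.set_active_adapters(adapter_name)"
   ]
  }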
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "adapter_hub",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
